{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Note: You can run this notebook on either \"CPU\" or \"GPU\".\n",
        "# In Colab, click \"Runtime > Change runtime type\" and select \"GPU\"."
      ],
      "metadata": {
        "id": "Kh7c1F1J_sD7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Uncomment the line below to install the ydata-synthetic library\n",
        "#!pip install ydata-synthetic"
      ],
      "metadata": {
        "id": "fwXSWiYu_tl0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Tabular Synthetic Data Generation with CTGAN\n",
        "- CTGAN, implemented according to the [paper](https://arxiv.org/pdf/1907.00503.pdf)\n",
        "- This notebook is an example of how to use CTGAN to generate synthetic tabular data with numeric and categorical features.\n",
        "\n",
        "## Dataset\n",
        "\n",
        "- The data used is the [Adult Census Income](https://www.kaggle.com/datasets/uciml/adult-census-income) dataset, which we will fetch with the `pmlb` library (a wrapper for the Penn Machine Learning Benchmark data repository).\n"
      ],
      "metadata": {
        "id": "6T8gjToi_yKA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from pmlb import fetch_data\n",
        "\n",
        "from ydata_synthetic.synthesizers.regular import RegularSynthesizer\n",
        "from ydata_synthetic.synthesizers import ModelParameters, TrainParameters"
      ],
      "metadata": {
        "id": "Ix4gZ9iSCVZI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Load the data"
      ],
      "metadata": {
        "id": "I0qyPwoECZ5x"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the Adult dataset and declare its numerical and categorical columns\n",
        "data = fetch_data('adult')\n",
        "num_cols = ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']\n",
        "cat_cols = ['workclass','education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex',\n",
        "            'native-country', 'target']"
      ],
      "metadata": {
        "id": "YeFPnJVOMVqd"
      },
      "execution_count": 2,
      "outputs": []
    },
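    {
      "cell_type": "markdown",
      "source": [
        "A quick sanity check can catch column-name typos before training. The cell below is a minimal sketch, assuming `data` is a pandas `DataFrame` and that every column should appear in exactly one of `num_cols` or `cat_cols`. The names `listed`, `missing`, and `unlisted` are illustrative only."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Compare the declared column lists against the columns actually present in the data\n",
        "listed = num_cols + cat_cols\n",
        "missing = sorted(set(listed) - set(data.columns))   # declared but not in the data\n",
        "unlisted = sorted(set(data.columns) - set(listed))  # in the data but not declared\n",
        "print('Declared but missing from data:', missing)\n",
        "print('Present in data but not declared:', unlisted)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },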
    {
      "cell_type": "markdown",
      "source": [
        "## Define model and training parameters"
      ],
      "metadata": {
        "id": "m6-dt5hLCgxG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Define the model and training parameters\n",
        "batch_size = 500\n",
        "epochs = 500 + 1\n",
        "learning_rate = 2e-4\n",
        "beta_1 = 0.5\n",
        "beta_2 = 0.9\n",
        "\n",
        "ctgan_args = ModelParameters(batch_size=batch_size,\n",
        "                             lr=learning_rate,\n",
        "                             betas=(beta_1, beta_2))\n",
        "\n",
        "train_args = TrainParameters(epochs=epochs)"
      ],
      "metadata": {
        "id": "9SsyBS2nMaSA"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Create and Train the CTGAN"
      ],
      "metadata": {
        "id": "68MoepO0Cpx6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "synth = RegularSynthesizer(modelname='ctgan', model_parameters=ctgan_args)\n",
        "synth.fit(data=data, train_arguments=train_args, num_cols=num_cols, cat_cols=cat_cols)"
      ],
      "metadata": {
        "id": "oIHMVgSZMg8_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Generate new synthetic data"
      ],
      "metadata": {
        "id": "xHK-SRPyDUin"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "synth_data = synth.sample(1000)\n",
        "print(synth_data)"
      ],
      "metadata": {
        "id": "0aa2g0RLMkqe",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "01808aa4-a700-4385-e7df-b2f7abd162a0"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "           age  workclass         fnlwgt  education  education-num  \\\n",
            "0    38.753654          4  179993.565472          8           10.0   \n",
            "1    36.408844          4  245841.807958          9           10.0   \n",
            "2    56.251066          4  400895.076058         11           13.0   \n",
            "3    26.846605          4  240156.201048         11           10.0   \n",
            "4    29.083102          1    5601.059126         11            9.0   \n",
            "..         ...        ...            ...        ...            ...   \n",
            "995  79.281276          4   30664.183560          1           10.0   \n",
            "996  51.423132          4  414524.980527          1           10.0   \n",
            "997  17.342915          6  177716.451926         11           13.0   \n",
            "998  39.298867          4  132011.369567         15           12.0   \n",
            "999  46.977763          2   92662.371635          9           13.0   \n",
            "\n",
            "     marital-status  occupation  relationship  race  sex  capital-gain  \\\n",
            "0                 4           0             3     4    0     55.771499   \n",
            "1                 6           7             0     4    1    124.337939   \n",
            "2                 4           3             3     4    1     27.968087   \n",
            "3                 4           6             1     4    0     25.065678   \n",
            "4                 6           3             0     4    0    126.269337   \n",
            "..              ...         ...           ...   ...  ...           ...   \n",
            "995               2           0             3     4    1      4.393001   \n",
            "996               4           7             3     2    0     54.841598   \n",
            "997               4           4             4     4    0     99.394428   \n",
            "998               4          14             1     4    1     97.834797   \n",
            "999               4           8             1     4    0     51.258308   \n",
            "\n",
            "     capital-loss  hours-per-week  native-country  target  \n",
            "0       -1.271118       39.749641              39       1  \n",
            "1       -2.114950       44.488198              39       1  \n",
            "2        1.541738       40.042696              39       1  \n",
            "3        1.148560       39.952615              39       1  \n",
            "4       -1.786768       39.808085              39       0  \n",
            "..            ...             ...             ...     ...  \n",
            "995      0.224015       50.580637              39       1  \n",
            "996      1.319341        4.441194              39       1  \n",
            "997     -5.231663       39.779674              39       1  \n",
            "998      1.595817       39.731359              13       1  \n",
            "999      1.129814       39.838415              39       1  \n",
            "\n",
            "[1000 rows x 15 columns]\n"
          ]
        }
      ]
    }
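    ,
    {
      "cell_type": "markdown",
      "source": [
        "One rough way to judge the sample is to compare summary statistics of the numerical columns in the real and synthetic data. The cell below is a minimal sketch using pandas only, not a full quality evaluation. It assumes `synth_data` has the same column names as `data`, and the name `comparison` is illustrative."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "\n",
        "# Side-by-side mean and standard deviation of each numerical column\n",
        "comparison = pd.concat(\n",
        "    [data[num_cols].agg(['mean', 'std']).T,\n",
        "     synth_data[num_cols].agg(['mean', 'std']).T],\n",
        "    axis=1, keys=['real', 'synthetic'])\n",
        "print(comparison)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    }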
  ]
}